
Add GPU implementation of NMSv2 op #28745

Merged
11 commits merged into tensorflow:master on Jun 10, 2019

Conversation

samikama
Contributor

This PR adds a GPU implementation of the NMSv2 op. It also registers a FakeGPU kernel for the CombinedNonMaxSuppression op, to work around issues caused by the lack of a GPU implementation until a proper one can be built on top of the current GPU kernels.
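For readers unfamiliar with the pattern, here is a speculative sketch of what a "fake" GPU registration can look like in TensorFlow: reusing a CPU kernel class for DEVICE_GPU with all tensors pinned to host memory. This shows the general technique only, not necessarily this PR's exact registration; the class name is hypothetical.

#include "tensorflow/core/framework/op_kernel.h"

// Illustration only: register a (hypothetical) CPU implementation class for
// DEVICE_GPU, forcing all inputs into host memory so the CPU code path can
// run unchanged when the op is placed on a GPU. Outputs could be pinned the
// same way; type constraints are elided here.
REGISTER_KERNEL_BUILDER(Name("CombinedNonMaxSuppression")
                            .Device(DEVICE_GPU)
                            .HostMemory("boxes")
                            .HostMemory("scores")
                            .HostMemory("max_output_size_per_class")
                            .HostMemory("max_total_size")
                            .HostMemory("iou_threshold")
                            .HostMemory("score_threshold"),
                        CombinedNonMaxSuppressionOp);  // hypothetical class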

…n of CombinedNonMaxSuppression op for Funcdef executions in TFTRT fallback path
@tensorflow-bot tensorflow-bot bot added the size:L (CL Change Size: Large) label May 15, 2019
@samikama
Contributor Author

@tfboyd This is the first of a series of PRs that will improve performance on object detection networks.

@samikama
Contributor Author

The test for the new op is blocked by #28744, since GPU tensors are not correctly transferred to the host without it.

@aaroey aaroey removed the request for review from tatianashp May 15, 2019 20:52
@rthadur rthadur self-assigned this May 15, 2019
@rthadur rthadur added this to Assigned Reviewer in PR Queue via automation May 15, 2019
@rthadur rthadur added the comp:ops (OPs related issues) label May 15, 2019
samikama added a commit to samikama/tensorflow that referenced this pull request May 17, 2019
@rthadur rthadur requested a review from aaroey May 17, 2019 20:21
@samikama samikama mentioned this pull request May 20, 2019
@aaroey aaroey requested a review from chsigg May 22, 2019 19:45
@aaroey
Member

aaroey commented May 22, 2019

Hi @chsigg, could you please help to take a look at this PR?
Thanks.

tensorflow/core/kernels/non_max_suppression_op.cu.cc (outdated review thread, resolved)
void NMSKernel(const Box* d_desc_sorted_boxes, const int nboxes,
const float thresh, const int mask_ld, int* d_delete_mask,
bool flip_boxes = false) {
// Storing boxes used by this CUDA block in the shared memory
Contributor


Comments should end with a '.'.

// One 1D line loads the boxes for the x-dimension
if (threadIdx.y == 0) {
const Box box = d_desc_sorted_boxes[i_to_load];
Box flipped = box;
Contributor


I would do this on 'box' directly with swap. It's unexpected to call this flipped when it's only flipped if flip_boxes is true.

__launch_bounds__(NMS_BLOCK_DIM * NMS_BLOCK_DIM, 4) __global__
void NMSKernel(const Box* d_desc_sorted_boxes, const int nboxes,
const float thresh, const int mask_ld, int* d_delete_mask,
bool flip_boxes = false) {
Contributor


Prefer no default arguments.

Would it help performance to make this a template parameter?
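For illustration, a minimal sketch of the reviewer's suggestion, with flip_boxes hoisted into a template parameter. The NMS_BLOCK_DIM value, the Box layout, and the kernel body are assumptions: the real kernel computes a delete mask, while this one only normalizes coordinates, just to show the compile-time dispatch.

#include <cuda_runtime.h>

constexpr int NMS_BLOCK_DIM = 16;       // assumed value
struct Box { float x1, y1, x2, y2; };   // assumed layout

// Each instantiation compiles without the runtime branch.
template <bool kFlipBoxes>
__launch_bounds__(NMS_BLOCK_DIM * NMS_BLOCK_DIM, 4) __global__
void NMSKernelSketch(const Box* d_boxes, const int nboxes, Box* d_out) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= nboxes) return;
  Box b = d_boxes[i];
  if (kFlipBoxes) {  // folded away at compile time per specialization
    if (b.x1 > b.x2) { const float t = b.x1; b.x1 = b.x2; b.x2 = t; }
    if (b.y1 > b.y2) { const float t = b.y1; b.y1 = b.y2; b.y2 = t; }
  }
  d_out[i] = b;
}

// The runtime flag now lives in a single host-side branch.
void LaunchNMSKernelSketch(bool flip_boxes, const Box* d_boxes, int nboxes,
                           Box* d_out, cudaStream_t stream) {
  const int threads = NMS_BLOCK_DIM * NMS_BLOCK_DIM;
  const int blocks = (nboxes + threads - 1) / threads;
  if (flip_boxes) {
    NMSKernelSketch<true><<<blocks, threads, 0, stream>>>(d_boxes, nboxes, d_out);
  } else {
    NMSKernelSketch<false><<<blocks, threads, 0, stream>>>(d_boxes, nboxes, d_out);
  }
}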

}
}
__syncthreads();
const int i = i_block_offset + threadIdx.x;
Contributor


This is the same as i_to_load, no?

tensorflow/core/kernels/non_max_suppression_op.cu.cc (outdated review thread, resolved)
// both take about the same time
int nto_copy = std::min(NMS_CHUNK_SIZE, N);
cudaEvent_t copy_done;
cudaEventCreate(&copy_done);
Contributor


Use stream_executor::Event.

Contributor Author


@chsigg I wasn't able to find any examples of stream_executor::Event being used this way elsewhere in the kernels. Even then, I don't think the logic can be implemented with stream_executor::Event, since the framework has no mechanism equivalent to cudaEventSynchronize(). I could try to spin on event::poll, but that would be quite inefficient and would probably hinder the rest of the framework as well due to the locks it acquires. I would have preferred to chain these with ThenExecute(), but that would require converting all the NMS ops to AsyncOps, as well as a proper thread pool on the event manager: currently all events are executed on a single thread, and doing work there would block the event infrastructure. I could spawn the work onto the CPU device thread pool from the event callback, but I am not sure that level of complexity is justified.
How would you propose I use stream_executor::Event? It is possible that I am missing something obvious.
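For context, a minimal sketch (with assumed names and a simplified loop) of the cudaEvent pattern being discussed: chunked device-to-host copies fenced by a blocking event wait.

#include <algorithm>
#include <cuda_runtime.h>

// Sketch: copy the delete mask back in chunks, using a CUDA event as the
// blocking fence that stream_executor currently has no equivalent for.
void CopyMaskInChunks(const int* d_buf, int* h_buf, int n, int chunk_size,
                      cudaStream_t stream) {
  cudaEvent_t copy_done;
  cudaEventCreate(&copy_done);
  for (int offset = 0; offset < n; offset += chunk_size) {
    const int nto_copy = std::min(chunk_size, n - offset);
    cudaMemcpyAsync(h_buf + offset, d_buf + offset, nto_copy * sizeof(int),
                    cudaMemcpyDeviceToHost, stream);
    cudaEventRecord(copy_done, stream);
    cudaEventSynchronize(copy_done);  // blocking host-side wait
    // ... consume h_buf[offset, offset + nto_copy) on the host ...
  }
  cudaEventDestroy(copy_done);
}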

explicit NonMaxSuppressionV2GPUOp(OpKernelConstruction* context)
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {
Contributor


Add a comment explaining what the implementation does. Inline comments (e.g., above the sections below) are also fine.

tensorflow/core/kernels/non_max_suppression_op.cu.cc (two outdated review threads, resolved)
@rthadur rthadur requested a review from chsigg May 29, 2019 05:52
tensorflow/core/kernels/non_max_suppression_op.cu.cc (nine outdated review threads, resolved)
@tensorflow-bot tensorflow-bot bot added the kokoro:force-run (Tests on submitted change) and ready to pull (PR ready for merge process) labels Jun 4, 2019
@kokoro-team kokoro-team removed the kokoro:force-run (Tests on submitted change) label Jun 4, 2019
@aaroey
Member

aaroey commented Jun 7, 2019

Sorry for the delay; there are some internal test failures and I'm still trying to fix them.

@tensorflow-copybara tensorflow-copybara merged commit 7f78ad5 into tensorflow:master Jun 10, 2019
PR Queue automation moved this from Approved by Reviewer to Merged Jun 10, 2019
tensorflow-copybara pushed a commit that referenced this pull request Jun 10, 2019
pooyadavoodi pushed 3 commits to pooyadavoodi/tensorflow that referenced this pull request Jul 16, 2019
@ppwwyyxx
Contributor

I believe this implementation is wrong: it does not agree with the CPU version of NMS Op.

In this implementation, the intersection area in the IoU computation uses (x2-x1+1)*(y2-y1+1), as can be seen here:

const float w = fdimf(xx2 + 1.0f, xx1);
const float h = fdimf(yy2 + 1.0f, yy1);
const float intersection = w * h;

However, the CPU version uses (x2-x1)*(y2-y1), as can be seen here:

const T intersection_area =
    std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) *
    std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0));
return intersection_area / (area_i + area_j - intersection_area);

For many inputs this may not have an effect at all. But for certain inputs the two versions will produce inconsistent results.

If your goal is to run object detection models, note that the "+1" is a legacy issue, and at Facebook we're trying to avoid the version with "+1". See this PR that handles "+1" in Caffe2.
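To make the discrepancy concrete, here is a small self-contained comparison with a hypothetical pair of boxes (not from the test suite) where the two formulas land on opposite sides of a 0.2 threshold:

#include <algorithm>
#include <cstdio>

// Boxes in (x1, y1, x2, y2) form. The "+1" convention treats coordinates
// as inclusive pixel indices.
float IouCpu(const float a[4], const float b[4]) {  // (x2-x1)*(y2-y1)
  const float w = std::max(0.0f, std::min(a[2], b[2]) - std::max(a[0], b[0]));
  const float h = std::max(0.0f, std::min(a[3], b[3]) - std::max(a[1], b[1]));
  const float inter = w * h;
  const float area_a = (a[2] - a[0]) * (a[3] - a[1]);
  const float area_b = (b[2] - b[0]) * (b[3] - b[1]);
  return inter / (area_a + area_b - inter);
}

float IouPlusOne(const float a[4], const float b[4]) {  // (x2-x1+1)*(y2-y1+1)
  const float w =
      std::max(0.0f, std::min(a[2], b[2]) + 1.0f - std::max(a[0], b[0]));
  const float h =
      std::max(0.0f, std::min(a[3], b[3]) + 1.0f - std::max(a[1], b[1]));
  const float inter = w * h;
  const float area_a = (a[2] - a[0] + 1.0f) * (a[3] - a[1] + 1.0f);
  const float area_b = (b[2] - b[0] + 1.0f) * (b[3] - b[1] + 1.0f);
  return inter / (area_a + area_b - inter);
}

int main() {
  const float a[4] = {0, 0, 2, 2};
  const float b[4] = {1, 1, 3, 3};
  std::printf("cpu: %.3f  gpu(+1): %.3f\n", IouCpu(a, b), IouPlusOne(a, b));
  // Prints cpu: 0.143  gpu(+1): 0.286, one on each side of a 0.2 threshold,
  // so the two NMS versions would suppress differently.
}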

@aaroey
Member

aaroey commented Jul 18, 2019

@samikama would you please fix the issue @ppwwyyxx mentioned above?

@SpaceInvader61

SpaceInvader61 commented Jul 19, 2019

@samikama I tried to use your kernel from Python and I get a segmentation fault when running this simple script:

import tensorflow as tf

tf.enable_eager_execution()

from tensorflow.python.ops import gen_image_ops

with tf.device("/device:GPU:0"):
    boxes = tf.constant([[1.0, 1.0, 1.0, 1.0]], dtype=tf.float32)
    scores = tf.constant([1.0], dtype=tf.float32)
    max_output_size = tf.constant(10, dtype=tf.int32)
    iou_threshold = tf.constant(0.7, dtype=tf.float32)
    score_threshold = tf.constant(float('-inf'), dtype=tf.float32)
    print("Start")
    x = gen_image_ops.non_max_suppression_v2(boxes, scores, max_output_size, iou_threshold, score_threshold)
    print("End")
    print(x)
I ran it with:

docker run --runtime=nvidia -it -v $PWD:/tf -w /tf tensorflow/tensorflow:nightly-gpu-py3
python pyscript.py

The output is

Start
Segmentation fault

Correct me if I am doing something wrong.

@samikama
Contributor Author

samikama commented Jul 20, 2019

@ppwwyyxx Thanks for catching that. I made fixes in #30893 to support both the legacy case and an implementation identical to the CPU one.
@SpaceInvader61 It looks like we missed that some input tensors need to be host tensors once they were changed from attributes to tensors. Can you try with #30893? Also, your example is probably not really exercising nms_v2, since the signature of nms_v2 is

non_max_suppression_v2(boxes, scores, max_output_size, iou_threshold, name=None)

Another point: you are passing a box with zero surface area, and it is the only box. Even though there is a single-box test in the test suite, we didn't have an invalid-box test. I will add a fix for that in an upcoming PR.
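For reference, a hedged sketch of the host-memory pinning described above; the actual registration in #30893 may differ, and any type constraints are elided:

#include "tensorflow/core/framework/op_kernel.h"

// Sketch: keep the scalar inputs in host memory so the op can read them on
// the CPU without a device-to-host copy; boxes and scores stay on the GPU.
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("max_output_size")
                            .HostMemory("iou_threshold"),
                        NonMaxSuppressionV2GPUOp);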

@SpaceInvader61

@samikama with .HostMemory() it works, thank you so much! 👍

@AkshayKhatriKodiak

Has this made it into any TensorFlow release? As far as I know, it wasn't included in 1.13 or 1.14. How about TensorFlow 2.0?

pooyadavoodi pushed a commit to pooyadavoodi/tensorflow that referenced this pull request Nov 5, 2019
DEKHTIARJonathan pushed commits to DEKHTIARJonathan/tensorflow that referenced this pull request on Mar 17, 2020; May 26, 2020; Jul 9, 2020; and Oct 2, 2020
nouiz pushed a commit to nouiz/tensorflow that referenced this pull request Dec 14, 2020